import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import numpy as np
from dorefanet import *
from typing import Type, Any, Callable, List, Optional, Tuple, Union
import math

# Output channel count of every block of the network, in forward order.
# Each pair below is (channel width, number of consecutive blocks using it).
stage_out_channel = [
    width
    for width, repeats in ((16, 2), (32, 2), (128, 2), (512, 6), (2048, 2))
    for _ in range(repeats)
]

def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 (pointwise) ``nn.Conv2d``."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise

class firstconv3x3(nn.Module):
    """Full-precision stem: 3x3 convolution, then batch norm, then ReLU."""

    def __init__(self, inp, oup, stride):
        super().__init__()

        # The attribute names are also the state_dict keys, so they stay
        # conv1 / bn1 / relu.
        self.conv1 = nn.Conv2d(
            inp, oup, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(oup)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        return self.relu(self.bn1(self.conv1(x)))

class LearnableBias(nn.Module):
    """Trainable additive offset holding one value per channel, starting at zero."""

    def __init__(self, out_chn):
        super().__init__()
        zero_offset = torch.zeros(1, out_chn, 1, 1)
        self.bias = nn.Parameter(zero_offset, requires_grad=True)

    def forward(self, x):
        """Return ``x`` with the per-channel offset added."""
        # expand_as stretches the (1, out_chn, 1, 1) parameter to x's shape.
        offset = self.bias.expand_as(x)
        return x + offset

class TypeNDQ(nn.Module):
    """Two-stage residual block: depthwise 3x3 stage, then a quantized conv stage.

    Stage 1 shifts the input with a learnable bias and feeds it to 1, 2 or 4
    parallel depthwise 3x3 conv + batch-norm branches (one branch per multiple
    of ``inplanes`` in ``planes``). The shortcut is added to every branch and
    the branches are concatenated along the channel axis, then passed through
    PReLU + batch norm.

    Stage 2 shifts again, applies ``QuantizationActivation`` and a
    ``QuantizationConv2d`` + batch norm (3x3 when ``planes`` is 16 or 32,
    otherwise 1x1), adds the stage-1 output as a residual and finishes with
    PReLU + batch norm.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels; must equal 1x, 2x or 4x ``inplanes``.
        stride: Stride of the depthwise convolutions. Any value other than 1
            switches the shortcut to a 2x2 average pool with stride 2.
            NOTE(review): that pool only matches the branches for stride 2.
        downsample: Accepted for interface compatibility; not used.
        w_bits: Weight bit width forwarded to ``QuantizationConv2d``.
        a_bits: Activation bit width forwarded to ``QuantizationActivation``.

    Raises:
        ValueError: If ``planes`` is not 1x, 2x or 4x ``inplanes``.

    NOTE(review): ``QuantizationActivation`` / ``QuantizationConv2d`` are not
    defined in this file; presumably they come from the ``dorefanet`` import.
    """

    # Attribute-name suffixes of the parallel depthwise branches. The names
    # (depthwiseconv3x3a ... depthwiseconv3x3d) are state_dict keys and must
    # not change.
    _BRANCH_SUFFIXES = "abcd"

    @staticmethod
    def _branch_count(inplanes, planes):
        """Return how many depthwise branches widen ``inplanes`` to ``planes``."""
        for count in (1, 2, 4):
            if count * inplanes == planes:
                return count
        raise ValueError(
            "TypeNDQ requires planes to be 1x, 2x or 4x inplanes, "
            "got inplanes={}, planes={}".format(inplanes, planes)
        )

    @staticmethod
    def _depthwise_branch(channels, stride):
        """Build one depthwise 3x3 convolution followed by batch norm."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=(3, 3), stride=(stride, stride),
                      padding=(1, 1), groups=channels, bias=False),
            nn.BatchNorm2d(channels),
        )

    def __init__(self, inplanes, planes, stride=1, downsample=None, w_bits=1, a_bits=1):
        super(TypeNDQ, self).__init__()

        # Fail at construction time instead of with a channel mismatch deep
        # inside forward().
        branch_count = self._branch_count(inplanes, planes)

        self.stride = stride
        self.inplanes = inplanes
        self.planes = planes

        self.move0c = LearnableBias(inplanes)
        self.binary_activation = QuantizationActivation(a_bits)

        # Registered in a, b, c, d order so parameter ordering (and hence
        # seeded initialisation) is the same as when they were spelled out.
        for suffix in self._BRANCH_SUFFIXES[:branch_count]:
            setattr(self, "depthwiseconv3x3" + suffix,
                    self._depthwise_branch(inplanes, stride))

        # Identity shortcut unless the depthwise stage changes resolution.
        if stride != 1:
            self.shortcut = nn.Sequential(
                nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
            )
        else:
            self.shortcut = nn.Sequential()

        self.relu1 = nn.Sequential(
            nn.PReLU(planes),
            nn.BatchNorm2d(planes)
        )

        self.move1c = LearnableBias(planes)

        # Narrow blocks (16 / 32 channels) use a 3x3 quantized conv, wider
        # blocks a 1x1 one.
        if planes in (16, 32):
            kernel_size, padding = (3, 3), (1, 1)
        else:
            kernel_size, padding = (1, 1), (0, 0)
        self.binconv3x3a = nn.Sequential(
            QuantizationConv2d(planes, planes, kernel_size=kernel_size, stride=(1, 1),
                               padding=padding, bias=False, w_bits=w_bits),
            nn.BatchNorm2d(planes)
        )

        self.relu2 = nn.Sequential(
            nn.PReLU(planes),
            nn.BatchNorm2d(planes)
        )

    def adjust_epoch(self, epoch):
        """Record the current training epoch on the module as ``self.epoch``."""
        self.epoch = epoch

    def forward(self, x):
        """Run both stages on ``x`` and return a tensor with ``planes`` channels."""
        shifted = self.move0c(x)
        # The shortcut has no parameters, so one evaluation serves all branches.
        identity = self.shortcut(x)

        branch_count = self._branch_count(self.inplanes, self.planes)
        branch_outputs = []
        for suffix in self._BRANCH_SUFFIXES[:branch_count]:
            branch_out = getattr(self, "depthwiseconv3x3" + suffix)(shifted)
            # Every branch gets the shortcut exactly once. Previously, in the
            # 4x case, the second branch got none and the third got it twice.
            branch_out += identity
            branch_outputs.append(branch_out)

        if branch_count == 1:
            out = branch_outputs[0]
        else:
            out = torch.cat(branch_outputs, dim=1)

        out = self.relu1(out)

        residual2 = out
        out = self.move1c(out)
        out = self.binary_activation(out)
        out = self.binconv3x3a(out)
        out += residual2
        out = self.relu2(out)

        return out


class qb_net_small(nn.Module):
    """Classifier built from a conv stem and a chain of ``TypeNDQ`` blocks.

    The module-level ``stage_out_channel`` list gives the output width of every
    block. The first entry is produced by a stride-2 ``firstconv3x3`` stem on a
    3-channel input; each later entry by a ``TypeNDQ`` block. A block uses
    stride 2 whenever its width differs from the previous one (and is not 16),
    and stride 1 otherwise. Global average pooling and a linear layer follow.

    Args:
        w_bits: Weight bit width passed to every ``TypeNDQ`` block.
        a_bits: Activation bit width passed to every ``TypeNDQ`` block.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, w_bits, a_bits, num_classes=1000):
        super(qb_net_small, self).__init__()
        self.feature = nn.ModuleList()
        for index, width in enumerate(stage_out_channel):
            if index == 0:
                self.feature.append(firstconv3x3(3, width, 2))
                continue
            previous_width = stage_out_channel[index - 1]
            # Downsample exactly where the channel width changes.
            widens = previous_width != width and width != 16
            stride = 2 if widens else 1
            self.feature.append(
                TypeNDQ(previous_width, width, stride, w_bits=w_bits, a_bits=a_bits)
            )

        self.pool1 = nn.AdaptiveAvgPool2d(1)
        # Derive the classifier input width from the last block instead of
        # hard-coding it, so it stays in sync with stage_out_channel.
        self.fc = nn.Linear(stage_out_channel[-1], num_classes)

    def forward(self, x):
        """Return the ``num_classes`` outputs of the final linear layer for ``x``."""
        for block in self.feature:
            x = block(x)

        x = self.pool1(x)
        # Drop the 1x1 spatial dimensions left by the global pool.
        x = torch.flatten(x, 1)
        return self.fc(x)



